# Copyright (c) 2012 Ben Reynwar
# Released under MIT License (see LICENSE.txt)

"""
MyHDL Test Bench to check the vericode FFT.
"""

import os
import random
import unittest
import logging

from numpy import fft

from fpga_sdrlib.conversions import c_to_int, cs_to_int, int_to_c, int_to_cs
from fpga_sdrlib.testbench import TestBenchIcarus
from fpga_sdrlib.fft.build import generate_dit_executable
from fpga_sdrlib import config

logger = logging.getLogger(__name__)

class DITTestBenchIcarus(TestBenchIcarus):
    """
    Test bench that drives the generated DIT FFT executable under Icarus.

    Args:
        name: A name to use for generated files.
        fft_length: The FFT length (must be a power of 2).
        in_samples: A list of complex points to send.
        sendnth: Send an input on every `sendnth` clock cycle.
        in_ms: A list of the meta data to send.
        start_msgs: Accepted for interface compatibility but currently
            ignored (it is not forwarded to the base class).
        defines: Macro definitions (constants) to use in verilog code.
    """

    # Additional signal names for the base test bench; presumably consumed
    # by TestBenchIcarus when wiring up the DUT -- confirm in the base class.
    extra_signal_names = ['first']

    def __init__(self, name, fft_length, in_samples, sendnth=config.default_sendnth,
                 in_ms=None, start_msgs=None, defines=config.default_defines):
        parent = super(DITTestBenchIcarus, self)
        parent.__init__(name, in_samples, sendnth, in_ms, defines=defines)
        # Kept so prepare() can build an executable for this FFT size.
        self.fft_length = fft_length

    def prepare(self):
        """Build the DIT FFT simulation executable and store it on self."""
        built = generate_dit_executable(self.name, self.fft_length,
                                        self.defines)
        self.executable = built

class TestFFT(unittest.TestCase):
    """Compare the simulated DIT FFT against numpy's FFT on random input."""

    def setUp(self):
        # Seeded generator so every run uses the same "random" input.
        rg = random.Random(0)
        self.myrand = rg.random
        self.myrandint = rg.randint

    def test_sixteen(self):
        width = 16
        nlog2 = 4
        # Number of FFTs to perform.
        N_data_sets = 4
        # How often to send input.
        # For large FFTs this must be larger since the speed scales as NlogN.
        # Otherwise we get an overflow error.
        sendnth = 10
        self.random_template(nlog2, width, N_data_sets, sendnth)

    def test_four(self):
        width = 16
        nlog2 = 2
        N_data_sets = 4
        sendnth = 2
        self.random_template(nlog2, width, N_data_sets, sendnth)

    def test_overflow(self):
        width = 16
        nlog2 = 4
        N_data_sets = 4
        # Sending an input every cycle is expected to be too fast for a
        # 16-point FFT, so this should raise an overflow error.
        sendnth = 1
        # The arguments must be passed individually. Wrapping them in a single
        # tuple makes random_template fail with a TypeError (wrong argument
        # count), which also satisfies assertRaises and hides whether the
        # simulation overflows at all.
        # Exception (rather than the Python-2-only StandardError) is the common
        # base on both Python 2 and 3.
        # NOTE(review): assertion failures inside random_template are also
        # Exceptions, so wrong-but-not-overflowing output would pass as well.
        self.assertRaises(Exception, self.random_template,
                          nlog2, width, N_data_sets, sendnth)

    def random_template(self, nlog2, width, N_data_sets, sendnth):
        """
        Test the DUT with a random complex stream.

        Args:
            nlog2: log2 of the FFT length.
            width: Value for the WIDTH verilog define.
            N_data_sets: Number of consecutive FFTs to perform.
            sendnth: Send an input on every `sendnth` clock cycle.
        """
        N = pow(2, nlog2)
        # Approximately how many simulation steps we'll need.
        steps_rqd = 2*N_data_sets*int(40.0 / 8 / 3 * nlog2 * N)
        # Generate random input with real and imaginary parts in [-1, 1).
        data_sets = []
        data = []
        for i in range(N_data_sets):
            nd = [self.myrand()*2-1 + self.myrand()*2j-1j for x in range(N)]
            data_sets.append(nd)
            data += nd
        mwidth = 3
        # One random meta-data value (mwidth bits wide) per input sample.
        ms = [self.myrandint(0, pow(2, mwidth)-1) for d in data]
        # Create, setup and simulate the test bench.
        defines = config.updated_defines({"DEBUG": False,
                                          "WIDTH": width,
                                          "MWIDTH": mwidth})
        tb = DITTestBenchIcarus('standard', N, data, sendnth, ms, defines=defines)
        tb.prepare()
        tb.run(steps_rqd)

        # Confirm that our data is correct.
        self.assertEqual(len(tb.out_samples), len(data))
        rffts = [tb.out_samples[N*i: N*(i+1)] for i in range(N_data_sets)]
        # Compare the FFT to that generated by numpy.
        # The FFT from our DUT is divided by N to prevent overflow so we do the
        # same to the numpy output.
        effts = [[x/N for x in fft.fft(data_set)] for data_set in data_sets]
        self.assertEqual(len(rffts), len(effts))
        max_delta = 0.02
        for rfft, efft in zip(rffts, effts):
            self.assertEqual(len(rfft), len(efft))
            for e, r in zip(efft, rfft):
                self.assertLess(abs(r-e), max_delta)
        # Compare ms. Check lengths first: zip() stops at the shorter list and
        # would otherwise silently pass when meta data is missing.
        self.assertEqual(len(tb.out_ms), len(ms))
        for r, e in zip(tb.out_ms, ms):
            self.assertEqual(r, e)

if __name__ == '__main__':
    # Turn on DEBUG-level logging before handing control to unittest.
    log_level = logging.DEBUG
    config.setup_logging(log_level)
    unittest.main()
